# Copyright 2019 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from . import sparse_conv_ext


def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
    """Return the per-axis output extent of a regular (non-transposed) convolution.

    Each axis uses ``(in + 2 * pad - dil * (k - 1) - 1) // stride + 1``,
    except that an axis whose kernel size is ``-1`` gets an extent of 1.

    Args:
        input_size: Sequence of input extents, one per spatial axis.
        kernel_size: Per-axis kernel sizes (``-1`` allowed, see above).
        stride: Per-axis strides.
        padding: Per-axis paddings.
        dilation: Per-axis dilations.

    Returns:
        list: One output extent per axis of ``input_size``.
    """
    shape = []
    for axis, extent in enumerate(input_size):
        k = kernel_size[axis]
        computed = (
            extent + 2 * padding[axis] - dilation[axis] * (k - 1) - 1
        ) // stride[axis] + 1
        # The formula is always evaluated; a -1 kernel just overrides the result.
        shape.append(1 if k == -1 else computed)
    return shape


def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation, output_padding):
    """Return the per-axis output extent of a transposed convolution.

    Each axis uses ``(in - 1) * stride - 2 * pad + k + output_padding``.

    Args:
        input_size: Sequence of input extents, one per spatial axis.
        kernel_size: Per-axis kernel sizes; ``-1`` is rejected.
        stride: Per-axis strides.
        padding: Per-axis paddings.
        dilation: Accepted for signature symmetry with
            ``get_conv_output_size`` but not read by this formula.
        output_padding: Per-axis extra output padding.

    Returns:
        list: One output extent per axis of ``input_size``.

    Raises:
        ValueError: If any kernel size equals ``-1``.
    """
    # NOTE(review): torch.nn.ConvTranspose*d uses dil * (k - 1) + 1 in place
    # of k; the two agree only when dilation == 1. Confirm this matches the
    # extension kernel before relying on dilated transposed convolutions.
    sizes = []
    for axis, extent in enumerate(input_size):
        k = kernel_size[axis]
        if k == -1:
            raise ValueError("deconv don't support kernel_size < 0")
        sizes.append(
            (extent - 1) * stride[axis] - 2 * padding[axis] + k + output_padding[axis]
        )
    return sizes


def get_indice_pairs(
    indices,
    batch_size,
    spatial_shape,
    ksize=3,
    stride=1,
    padding=0,
    dilation=1,
    out_padding=0,
    subm=False,
    transpose=False,
    grid=None,
):
    """Compute input/output index pairs for a sparse convolution.

    Scalar hyper-parameters are broadcast to one value per spatial dimension.
    The output spatial shape is then derived, and the work is delegated to the
    ``sparse_conv_ext.get_indice_pairs_*`` kernel matching the dimensionality.

    Args:
        indices: Tensor of active-site coordinates. ``indices.shape[1] - 1``
            is taken as the number of spatial dimensions (presumably the
            extra column is the batch index -- TODO confirm).
        batch_size: Passed through unchanged to the extension kernel.
        spatial_shape: Per-dimension size of the input grid.
        ksize, stride, padding, dilation, out_padding: Convolution
            hyper-parameters. Each is an int applied to every dimension, or a
            list/tuple with one entry per dimension.
        subm: If true, the output shape equals ``spatial_shape``.
        transpose: If true (and ``subm`` is false), the output shape uses the
            transposed-convolution formula.
        grid: Optional tensor. When given, the ``get_indice_pairs_grid_*``
            kernel is used and ``grid`` is passed right after ``indices``.

    Returns:
        Whatever the selected extension kernel returns.

    Raises:
        ValueError: If some dimension has both ``stride != 1`` and
            ``dilation != 1``.
        NotImplementedError: If the dimensionality is not 2, 3 or 4
            (2 or 3 when ``grid`` is given).
    """
    ndim = indices.shape[1] - 1

    def _per_dim(value):
        # Broadcast a scalar to a list with one entry per spatial dimension.
        if isinstance(value, (list, tuple)):
            return value
        return [value] * ndim

    ksize = _per_dim(ksize)
    stride = _per_dim(stride)
    padding = _per_dim(padding)
    dilation = _per_dim(dilation)
    out_padding = _per_dim(out_padding)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    for d, s in zip(dilation, stride):
        if s != 1 and d != 1:
            raise ValueError(
                "get_indice_pairs: stride and dilation cannot both differ from 1 "
                "in the same dimension (got stride={}, dilation={})".format(s, d)
            )

    if grid is None:
        kernel_names = {
            2: "get_indice_pairs_2d",
            3: "get_indice_pairs_3d",
            4: "get_indice_pairs_4d",
        }
    else:
        kernel_names = {
            2: "get_indice_pairs_grid_2d",
            3: "get_indice_pairs_grid_3d",
        }
    # Reject unsupported dimensionality before doing any output-shape work.
    if ndim not in kernel_names:
        raise NotImplementedError(
            "get_indice_pairs: {}D indices are not supported{}".format(
                ndim, " with grid" if grid is not None else ""
            )
        )

    if subm:
        out_shape = spatial_shape
    elif transpose:
        out_shape = get_deconv_output_size(
            spatial_shape, ksize, stride, padding, dilation, out_padding
        )
    else:
        out_shape = get_conv_output_size(spatial_shape, ksize, stride, padding, dilation)

    func = getattr(sparse_conv_ext, kernel_names[ndim])
    # Arguments shared by the plain and grid kernels, in the kernels' order.
    common_args = (
        batch_size,
        out_shape,
        spatial_shape,
        ksize,
        stride,
        padding,
        dilation,
        out_padding,
        int(subm),
        int(transpose),
    )
    if grid is None:
        return func(indices, *common_args)
    return func(indices, grid, *common_args)


def indice_conv(
    features, filters, indice_pairs, indice_pair_num, num_activate_out, inverse=False, subm=False
):
    """Run the sparse-convolution forward kernel.

    The kernel is chosen from ``filters.dtype``: float32 uses
    ``indice_conv_fp32`` and float16 uses ``indice_conv_half``. The boolean
    flags are converted to ints before being forwarded.

    Returns:
        Whatever the selected extension kernel returns.

    Raises:
        NotImplementedError: For any other filter dtype.
    """
    if filters.dtype == torch.float32:
        kernel = sparse_conv_ext.indice_conv_fp32
    elif filters.dtype == torch.half:
        kernel = sparse_conv_ext.indice_conv_half
    else:
        raise NotImplementedError
    return kernel(
        features,
        filters,
        indice_pairs,
        indice_pair_num,
        num_activate_out,
        int(inverse),
        int(subm),
    )


def fused_indice_conv(
    features, filters, bias, indice_pairs, indice_pair_num, num_activate_out, inverse, subm
):
    """Run the fused sparse-convolution forward kernel (which also takes ``bias``).

    The kernel is selected from the dtype shared by ``features`` and
    ``filters``: float16 uses ``fused_indice_conv_half`` and float32 uses
    ``fused_indice_conv_fp32``. The boolean flags are converted to ints
    before being forwarded.

    Returns:
        Whatever the selected extension kernel returns.

    Raises:
        NotImplementedError: If ``features`` and ``filters`` have different
            dtypes, or if that dtype is neither float16 nor float32.
    """
    # Both tensors are checked. Otherwise a mixed pair (e.g. half features
    # with float32 filters) could reach a kernel that presumably assumes one
    # element type for both.
    if features.dtype != filters.dtype:
        raise NotImplementedError(
            "fused_indice_conv: features dtype {} does not match filters dtype {}".format(
                features.dtype, filters.dtype
            )
        )
    if features.dtype == torch.half:
        func = sparse_conv_ext.fused_indice_conv_half
    elif features.dtype == torch.float32:
        func = sparse_conv_ext.fused_indice_conv_fp32
    else:
        raise NotImplementedError(
            "fused_indice_conv: unsupported dtype {}".format(features.dtype)
        )

    return func(
        features,
        filters,
        bias,
        indice_pairs,
        indice_pair_num,
        num_activate_out,
        int(inverse),
        int(subm),
    )


def indice_conv_backward(
    features, filters, out_bp, indice_pairs, indice_pair_num, inverse=False, subm=False
):
    """Run the sparse-convolution backward kernel.

    The kernel is chosen from ``filters.dtype``: float32 uses
    ``indice_conv_backward_fp32`` and float16 uses
    ``indice_conv_backward_half``. The boolean flags are converted to ints.

    Returns:
        Whatever the selected extension kernel returns.

    Raises:
        NotImplementedError: For any other filter dtype.
    """
    if filters.dtype == torch.float32:
        backward_op = sparse_conv_ext.indice_conv_backward_fp32
    elif filters.dtype == torch.half:
        backward_op = sparse_conv_ext.indice_conv_backward_half
    else:
        raise NotImplementedError
    flags = (int(inverse), int(subm))
    return backward_op(features, filters, out_bp, indice_pairs, indice_pair_num, *flags)


def indice_maxpool(features, indice_pairs, indice_pair_num, num_activate_out):
    """Run the sparse max-pooling forward kernel.

    ``features.dtype`` selects ``indice_maxpool_fp32`` (float32) or
    ``indice_maxpool_half`` (float16). All arguments are forwarded as-is.

    Returns:
        Whatever the selected extension kernel returns.

    Raises:
        NotImplementedError: For any other feature dtype.
    """
    dtype = features.dtype
    if dtype == torch.float32:
        pool_op = sparse_conv_ext.indice_maxpool_fp32
    elif dtype == torch.half:
        pool_op = sparse_conv_ext.indice_maxpool_half
    else:
        raise NotImplementedError
    return pool_op(features, indice_pairs, indice_pair_num, num_activate_out)


def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, indice_pair_num):
    """Run the sparse max-pooling backward kernel.

    ``features.dtype`` selects ``indice_maxpool_backward_fp32`` (float32) or
    ``indice_maxpool_backward_half`` (float16). All arguments are forwarded
    as-is.

    Returns:
        Whatever the selected extension kernel returns.

    Raises:
        NotImplementedError: For any other feature dtype.
    """
    dtype = features.dtype
    if dtype == torch.float32:
        grad_op = sparse_conv_ext.indice_maxpool_backward_fp32
    elif dtype == torch.half:
        grad_op = sparse_conv_ext.indice_maxpool_backward_half
    else:
        raise NotImplementedError
    return grad_op(features, out_features, out_bp, indice_pairs, indice_pair_num)
